rm(list=ls())
library(fields); library(maps); library(ncdf4);library(abind)
library(geosphere);library(mapdata)
setwd('/Volumes/Disk2/JJA_ozone_MAM/Harvard_Dataverse')

#----- load functions -----------------------
source("./Function/get_geo.R")
source("./Function/get_met.R")
source('./Function/read_detrend.R')
source('./Function/eof.mca.R')
source('./Function/cov4gappy.R')
source('./Function/Henderson.R')
source('./Function/filled.contour3.R')

# Diverging blue-white-red palette for the correlation maps: 61 break points
# on [-0.7, 0.7] give 60 colour bins. When no bin is centred exactly on zero,
# the two bins that touch r = 0 are painted white so that weak correlations of
# either sign blend into the background.
col.gen <- colorRampPalette(c("#0563FA", "white", "#ff4d4d"), space = "rgb")
levels <- seq(-0.7, 0.7, length.out = 61)
# centre of every colour bin
mid <- (levels[-length(levels)] + levels[-1]) / 2
cols <- col.gen(length(levels) - 1)
if (min(abs(mid)) > 0) {
  neg <- which(mid <= 0)
  pos <- which(mid >= 0)
  # first non-negative bin and last non-positive bin: the pair straddling zero
  ind <- sort(unique(c(intersect(neg + 1, pos), intersect(neg, pos - 1))))
  cols[ind] <- "white"
}


# Load the 'land' field from land.nc. Longitudes are shifted by -360
# (presumably 0..360 -> -360..0; confirm against the file) and the latitude
# vector and the data's latitude axis are both reversed.
name <- "./Data/land.nc"
datafile <- nc_open(name)
lon.land <- ncvar_get(datafile, varid = 'lon') - 360
lat.land <- rev(ncvar_get(datafile, varid = 'lat'))
land <- ncvar_get(datafile, varid = 'land')
land <- land[, rev(seq_along(lat.land))]
nc_close(datafile)

#=========== Part 1: read SST ================================
# Read HadISST monthly SST over a lon/lat box and average it over one season
# (calendar months start_month..end_month) for every year in
# first_year..last_year.
#
# Arguments:
#   name         file name inside ./Data/SST/
#   start_month  first month of the season (1-12)
#   end_month    last month of the season; must be >= start_month
#                (seasons that wrap across December are not supported)
#   xlim, ylim   longitude / latitude limits, in the file's coordinates
#   first_year, last_year  years to return (default 1979-2008)
#
# The time axis is assumed to be monthly and to start in January 1870; the
# year/month of each time step is derived from that (TODO confirm against the
# file's time units). The lon/lat selections are read as one contiguous
# hyperslab, so they must be contiguous in the file.
# NOTE(review): HadISST is documented to flag sea-ice points with -1000; if
# this file does so, those values enter the seasonal means unmasked.
#
# Returns a list with
#   data  lon x lat x year array of seasonal means (NaN where all months NA)
#   lon   selected longitudes
#   lat   selected latitudes, in reversed file order
#   date  the years (first_year:last_year)
read_HadISST <- function(name, start_month, end_month, xlim = c(-180, 0), ylim = c(-10, 70),
                         first_year = 1979, last_year = 2008) {
  stopifnot(start_month <= end_month, first_year <= last_year)

  datafile <- nc_open(paste0('./Data/SST/', name))
  # close the file even if one of the reads below fails
  on.exit(nc_close(datafile), add = TRUE)
  lon <- ncvar_get(datafile, varid = 'longitude')
  lat <- ncvar_get(datafile, varid = 'latitude')
  n_time <- length(ncvar_get(datafile, varid = 'time'))

  # year and month of every time step, counting months from January 1870
  step <- seq_len(n_time) - 1
  date <- cbind(1870 + step %/% 12, step %% 12 + 1)

  ind1 <- which(lon >= xlim[1] & lon <= xlim[2])
  ind2 <- which(lat >= ylim[1] & lat <= ylim[2])
  time.ind <- which(date[, 1] >= first_year & date[, 1] <= last_year)
  if (length(ind1) == 0 || length(ind2) == 0 || length(time.ind) == 0) {
    stop('read_HadISST: empty longitude, latitude or year selection in ', name, call. = FALSE)
  }

  # collapse_degen = FALSE keeps all three dimensions even when a selection
  # has length 1, so the 3-D indexing below stays valid
  origin_data <- ncvar_get(datafile, varid = 'sst',
                           start = c(ind1[1], ind2[1], time.ind[1]),
                           count = c(length(ind1), length(ind2), length(time.ind)),
                           collapse_degen = FALSE)
  date <- date[time.ind, , drop = FALSE]
  lon <- lon[ind1]
  lat <- rev(lat[ind2])
  # flip the data's latitude axis to match the reversed lat vector
  SST <- origin_data[, rev(seq_along(lat)), , drop = FALSE]

  years <- first_year:last_year
  temp <- array(NA, c(length(lon), length(lat), length(years)))
  for (k in seq_along(years)) {
    in_season <- date[, 1] == years[k] & date[, 2] >= start_month & date[, 2] <= end_month
    temp[, , k] <- rowMeans(SST[, , in_season, drop = FALSE], dims = 2, na.rm = TRUE)
  }
  list('data' = temp, 'lon' = lon, 'lat' = lat, 'date' = years)
}

# Map region for the SST correlations: longitudes 180-360 (passed to
# read_HadISST shifted by -360; its default xlim is c(-180, 0)) and 0-70N.
# Spring (MAM, months 3-5) and summer (JJA, months 6-8) seasonal means.
xlimit <- c(180, 360)
ylimit <- c(0, 70)
sst_xlim <- xlimit - 360
MAM_SSTs <- read_HadISST("HadISST_sst.nc", start_month = 3, end_month = 5, xlim = sst_xlim, ylim = ylimit)
JJA_SSTs <- read_HadISST("HadISST_sst.nc", start_month = 6, end_month = 8, xlim = sst_xlim, ylim = ylimit)

dev.new(width = 6, height = 6)

# Graphics settings shared by the map panels below.
mai <- c(0.1, 0.05, 0.2, 0.05)
mgp <- c(1.2, 0.4, 0)
tcl <- -0.2
ps <- 12

# Screen layout, one row per screen as c(left, right, bottom, top):
#   1      title strip across the top
#   2, 3   NCEP panels (left / right)
#   4, 5   20CR panels
#   6, 7   AMIP panels
#   8      colour-bar strip (overlaps the bottom row)
close.screen(all.screens = TRUE)
left_col <- c(0.0, 0.5)
right_col <- c(0.5, 1.0)
m <- rbind(
  c(0, 1, 0.93, 1),
  c(left_col, 0.70, 0.95),
  c(right_col, 0.70, 0.95),
  c(left_col, 0.45, 0.70),
  c(right_col, 0.45, 0.70),
  c(left_col, 0.20, 0.45),
  c(right_col, 0.20, 0.45),
  c(0.2, 0.8, 0.15, 0.45)
)
split.screen(m)

screen(1)
par(mai = mai, mgp = mgp, tcl = tcl, ps = ps)
text(x = 0.5, y = 0.53, "Correlations of JJA surface temperatures in the eastern US with SSTs",
     font = 1, cex = 1.1)

#========Part 2: NCEP Reanalysis =================================
# JJA surface air temperature from NCEP (read_surface, sourced from ./Function),
# area-averaged (cal.area_mean) over masked cells of the eastern-US box and
# passed through mov.detrend for 1979-2008.
# NOTE(review): this box is 100W-65W, 32.5N-50N, whereas the 20CR section uses
# 100W-75W, 32.5N-47.5N; confirm the difference is intentional.
surf_T <- read_surface('air.mon.mean.nc', 6, 8)
lon.frame <- surf_T$lon - 360
lat.frame <- surf_T$lat
yearly_O3 <- surf_T$data   # despite the O3 in the name, this is surf_T$data (temperature file)
# assumes read_surface returns one field per year for 1948-2013 (TODO confirm)
xx <- 1948:2013
time.ind <- xx >= 1979 & xx <= 2008
ind1 <- lon.frame >= -100 & lon.frame <= -65
ind2 <- lat.frame >= 32.5 & lat.frame <= 50
# plot.field(US.area,lon.frame,lat.frame)
US.mask <- sp.dissolve(land, lon.land, lat.land, lon.frame, lat.frame)
US.area <- cal.area(lon.frame, lat.frame)
US.area[!US.mask] <- NA   # drop cells where the mask is zero/FALSE
SE.O3 <- mov.detrend(cal.area_mean(yearly_O3, US.area, ind1, ind2)[time.ind])

# Panels (a)/(b): correlation of the NCEP eastern-US JJA temperature index with
# detrended SST from sst.mnmean.nc for MAM (month = 4) and JJA (month = 7),
# i.e. months month-1..month+1. Dashed lines mark +/- the r threshold returned
# by find.Rlim(0.05, n).
index <- 2
title_names <- c('(a) NCEP: East-JJA-T vs. MAM SST ', '(b) NCEP: East-JJA-T vs. JJA SST ')
for (month in c(4, 7)) {
  screen(index)
  # met=read_surface('air.mon.mean.nc',month-1,month+1,xlim=xlimit,ylim=ylimit)
  met <- read_SST('sst.mnmean.nc', month - 1, month + 1, xlim = xlimit, ylim = ylimit)
  # if(month==4) met=MAM_SSTs
  # if(month==7) met=JJA_SSTs
  lon <- met$lon
  lat <- met$lat
  # NOTE(review): time.ind was built for the NCEP years 1948-2013; this assumes
  # read_SST returns the same yearly time axis (confirm).
  month_SST <- array(NA, dim(met$data[, , time.ind]))
  for (i in seq_len(dim(month_SST)[1])) {
    for (j in seq_len(dim(month_SST)[2])) {
      month_SST[i, j, ] <- mov.detrend(met$data[i, j, time.ind])
    }
  }
  # The SST stack and the temperature index must cover the same years;
  # otherwise the correlation would pair misaligned series.
  stopifnot(dim(month_SST)[3] == length(SE.O3))

  ap <- find.cor2(month_SST, SE.O3, p_value = 1)
  ap[ap >= 0.7] <- 0.7
  ap[ap <= -0.7] <- -0.7
  par(mai = mai, mgp = mgp, tcl = tcl, ps = ps)
  # image(lon-360,lat, ap,zlim=c(-0.7,0.7),col=rwb.colors(32),xaxt='n',yaxt='n',xlab='',ylab='');
  plot_lon <- lon - 360
  filled.contour3(plot_lon, lat, ap, levels = levels, plot.axes = FALSE, col = cols)
  map("world", add = TRUE)
  cr <- find.Rlim(0.05, sum(time.ind))
  clines <- contourLines(plot_lon, lat, ap, levels = c(-cr, cr))
  # seq_along: contourLines returns an empty list when nothing crosses +/-cr,
  # and 1:length(clines) would then index clines[[1]] and fail.
  for (k in seq_along(clines)) {
    lines(clines[[k]]$x, clines[[k]]$y, lwd = 1, col = 1, lty = 2)
  }
  index <- index + 1
  title(title_names[(month - 1) / 3], cex.main = 1, font.main = 1)
}

#========Part 3: NOAA 20C Reanalysis =================================
# JJA surface air temperature from NOAA 20CR (read_20c, sourced from ./Function),
# area-averaged (cal.area_mean) over masked cells of the eastern-US box and
# passed through mov.detrend for 1979-2008.
# NOTE(review): this box (100W-75W, 32.5N-47.5N) is smaller than the one used
# for NCEP (100W-65W, 32.5N-50N); confirm the difference is intentional.
surf_T <- read_20c('air.sfc.mon.mean.nc', 6, 8)
lon.frame <- surf_T$lon - 360
lat.frame <- surf_T$lat
yearly_O3 <- surf_T$data   # despite the O3 in the name, this is surf_T$data (temperature file)
xx <- surf_T$date
time.ind <- xx >= 1979 & xx <= 2008
ind1 <- lon.frame >= -100 & lon.frame <= -75
ind2 <- lat.frame >= 32.5 & lat.frame <= 47.5
US.mask <- sp.dissolve(land, lon.land, lat.land, lon.frame, lat.frame)
US.area <- cal.area(lon.frame, lat.frame)
US.area[!US.mask] <- NA   # drop cells where the mask is zero/FALSE
SE.O3 <- mov.detrend(cal.area_mean(yearly_O3, US.area, ind1, ind2)[time.ind])

# Panels (c)/(d): correlation of the 20CR eastern-US JJA temperature index with
# the HadISST MAM (month = 4) and JJA (month = 7) seasonal means read in
# Part 1. Dashed lines mark +/- the r threshold returned by find.Rlim(0.05, n).
title_names <- c('(c) 20CR: East-JJA-T vs. MAM SST ', '(d) 20CR: East-JJA-T vs. JJA SST')
for (month in c(4, 7)) {
  screen(index)
  met <- if (month == 4) MAM_SSTs else JJA_SSTs
  # met$lon is already in the plotting convention (read_HadISST was called
  # with xlimit - 360), so it is used directly.
  lon <- met$lon
  lat <- met$lat
  month_SST <- array(NA, dim(met$data))
  for (i in seq_len(dim(month_SST)[1])) {
    for (j in seq_len(dim(month_SST)[2])) {
      month_SST[i, j, ] <- mov.detrend(met$data[i, j, ])
    }
  }
  # The SST stack and the temperature index must cover the same years;
  # otherwise the correlation would pair misaligned series.
  stopifnot(dim(month_SST)[3] == length(SE.O3))

  ap <- find.cor2(month_SST, SE.O3, p_value = 1)
  ap[ap >= 0.7] <- 0.7
  ap[ap <= -0.7] <- -0.7
  par(mai = mai, mgp = mgp, tcl = tcl, ps = ps)
  filled.contour3(lon, lat, ap, levels = levels, plot.axes = FALSE, col = cols)
  map("world", add = TRUE)
  cr <- find.Rlim(0.05, sum(time.ind))
  clines <- contourLines(lon, lat, ap, levels = c(-cr, cr))
  # seq_along: contourLines returns an empty list when nothing crosses +/-cr.
  for (k in seq_along(clines)) {
    lines(clines[[k]]$x, clines[[k]]$y, lwd = 1, col = 1, lty = 2)
  }
  index <- index + 1
  title(title_names[(month - 1) / 3], cex.main = 1, font.main = 1)
}

#========== Part 4: AMIP ================================
# Part 4 inputs: one correlation_T-SST.Rdata per model sub-directory of
# ./Data/AMIP/. Each file must define `correlation` (layer 1 is used as the
# MAM-SST map and layer 2 as the JJA-SST map below), `lon.out` and `lat.out`.
workdir <- './Data/AMIP/'
model.names <- list.dirs(path = workdir, full.names = FALSE, recursive = FALSE)
if (length(model.names) == 0) {
  stop('No AMIP model directories found under ', workdir, call. = FALSE)
}

# Load every model into its own environment so the .Rdata objects do not
# overwrite globals; the maps are bound once below instead of growing an
# array inside the loop.
model_fields <- lapply(seq_along(model.names), function(kk) {
  print('=====================')
  print(kk)
  datadir <- paste(workdir, model.names[kk], sep = '')
  env <- new.env()
  load(paste(datadir, '/correlation_T-SST.Rdata', sep = ""), envir = env)
  list(lon = env$lon.out, lat = env$lat.out, correlation = env$correlation)
})

# The maps are stacked cell by cell, so every model must share one grid; the
# last model's grid is used for plotting.
lon <- model_fields[[length(model_fields)]]$lon
lat <- model_fields[[length(model_fields)]]$lat
same_grid <- vapply(model_fields, function(f) {
  isTRUE(all.equal(f$lon, lon)) && isTRUE(all.equal(f$lat, lat))
}, logical(1))
if (!all(same_grid)) {
  warning('AMIP models on a different lon/lat grid: ',
          paste(model.names[!same_grid], collapse = ', '), call. = FALSE)
}

# Binding a list along = 3 always yields a lon x lat x model array, even for
# a single model, so dim(...)[3] downstream is the model count.
MAM_correlation <- abind(lapply(model_fields, function(f) f$correlation[, , 1]), along = 3)
JJA_correlation <- abind(lapply(model_fields, function(f) f$correlation[, , 2]), along = 3)


# Mask on the AMIP grid from the 'land' field; cells with mask >= 0.25 are
# blanked in both AMIP panels (presumably a land fraction; confirm).
mask <- sp.dissolve(land, lon.land, lat.land, lon, lat)
mask.ind <- (mask >= 0.25)

# Panel (e): multi-model median correlation with MAM SST. Dashed contours mark
# where 30% and 70% of the models have a positive correlation.
screen(6)
# number of model layers; a 2-D stack means a single model
n_models <- if (length(dim(MAM_correlation)) == 3) dim(MAM_correlation)[3] else 1
ap <- apply(MAM_correlation, c(1, 2), median, na.rm = TRUE)
ap[ap >= 0.7] <- 0.7
ap[ap <= -0.7] <- -0.7
ap[mask.ind] <- NA
par(mai = mai, mgp = mgp, tcl = tcl, ps = ps)
# image(lon,lat, ap,zlim=c(-0.7,0.7),col=rwb.colors(32),xaxt='n',yaxt='n',xlab='',ylab='');
ind <- (lat >= ylimit[1] & lat <= ylimit[2])
filled.contour3(lon, lat[ind], ap[, ind], levels = levels, plot.axes = FALSE, col = cols)
map("world", add = TRUE)
# fraction of models with r > 0, normalised by the MAM stack's own model count
spdata <- apply(MAM_correlation > 0, c(1, 2), sum, na.rm = TRUE) / n_models
spdata[mask.ind] <- NA
clines <- contourLines(lon, lat, spdata, levels = c(0.3, 0.7))
# seq_along: contourLines returns an empty list when no level is crossed.
for (k in seq_along(clines)) {
  lines(clines[[k]]$x, clines[[k]]$y, lwd = 1, col = 1, lty = 2)
}
title('(e) AMIP: East-JJA-T vs. MAM SST', cex.main = 1, font.main = 1)

# Panel (f): multi-model median correlation with JJA SST. Dashed contours mark
# where 30% and 70% of the models have a positive correlation.
screen(7)
# number of model layers; a 2-D stack means a single model
n_models <- if (length(dim(JJA_correlation)) == 3) dim(JJA_correlation)[3] else 1
ap <- apply(JJA_correlation, c(1, 2), median, na.rm = TRUE)
ap[ap >= 0.7] <- 0.7
ap[ap <= -0.7] <- -0.7
ap[mask.ind] <- NA
par(mai = mai, mgp = mgp, tcl = tcl, ps = ps)
# image(lon,lat, ap,zlim=c(-0.7,0.7),col=rwb.colors(32),xaxt='n',yaxt='n',xlab='',ylab='');
ind <- (lat >= ylimit[1] & lat <= ylimit[2])
filled.contour3(lon, lat[ind], ap[, ind], levels = levels, plot.axes = FALSE, col = cols)
map("world", add = TRUE)
# fraction of models with r > 0
spdata <- apply(JJA_correlation > 0, c(1, 2), sum, na.rm = TRUE) / n_models
spdata[mask.ind] <- NA
clines <- contourLines(lon, lat, spdata, levels = c(0.3, 0.7))
# seq_along: contourLines returns an empty list when no level is crossed.
for (k in seq_along(clines)) {
  # skip short contour fragments (fewer than 20 points)
  if (length(clines[[k]]$x) >= 20) {
    lines(clines[[k]]$x, clines[[k]]$y, lwd = 1, col = 1, lty = 2)
  }
}
title('(f) AMIP: East-JJA-T vs. JJA SST', cex.main = 1, font.main = 1)

# Screen 8: horizontal colour bar for the correlation maps (legend only;
# spdata is presumably just a placeholder here since zlim is fixed) plus an
# "r" label placed relative to the current user coordinates.
screen(8)
mai <- c(0.1, 0.1, 0.05, 0.2)
par(mai = mai, mgp = mgp, tcl = tcl, ps = 9)
bar_axis <- list(mgp = c(0, 0.5, 0), tcl = -0.2, padj = -1)
image.plot(spdata, legend.shrink = 0.9, legend.only = TRUE, zlim = c(-0.7, 0.7),
           col = cols, horizontal = TRUE, legend.width = 2.5, axis.args = bar_axis)
usr <- par("usr")
text("r", x = usr[2] - 0.22, y = usr[3] + 0.165, srt = 0, adj = 0,
     xpd = TRUE, cex = 1.8, font = 1)